# Switch for the expensive MCMC evaluation. It is not referenced in the R code
# of this notebook; presumably it is used as a knitr chunk option
# (e.g. `eval = eval_mcmc`) -- confirm. Uses FALSE rather than `F`, which can
# be reassigned.
eval_mcmc <- FALSE
This notebook simulates several sets of genealogies, each containing a single expansion, and tests MCMC inference against them to validate how well parameters are recovered across different scales.
This section validates how well model parameters are recovered across different population-size scales and expansion rates.
## Simulation settings ----
set.seed(1)
burn_in <- 0.1           # proportion of MCMC samples dropped by discard_burn_in()
n_it <- 1e7              # MCMC iterations; appears to mirror `-n` of the inference command
thinning <- n_it / 1e4   # appears to mirror `-t` of the inference command
repetitions <- 200       # simulated genealogies per (lambdar, population scale) pair
dims <- 1                # passed as `-e` to simulate_tree.R (number of expansions)
lambdar_range <- c(5)
pop_range <- c(5)

# Create a directory (and missing parents) unless it already exists, so the
# notebook can be re-run without "already exists" warnings.
ensure_dir <- function(path) {
  if (!dir.exists(path)) {
    dir.create(path, recursive = TRUE)
  }
  invisible(path)
}

## Write one simulation spec file per run ----
## Each spec file holds one simulate_tree.R command-line option per line.
## Specs go to test_spec/lambdar_<i>/pop_<j>_r_<l>.txt. The matching simulation
## output is directed (via `-o`) to test_out/lambdar_<i>/pop_<j>_r_<l>.
## Every run gets a distinct seed.
ensure_dir(file.path(base_dir, "test_spec"))
ensure_dir(file.path(base_dir, "test_out"))
seed <- 0
for (i in seq_along(lambdar_range)) {
  dir_name <- paste0("lambdar_", i)
  ensure_dir(file.path(base_dir, "test_spec", dir_name))
  ensure_dir(file.path(base_dir, "test_out", dir_name))
  for (j in seq_along(pop_range)) {
    for (l in seq_len(repetitions)) {
      params <- c(paste0("-e ", dims),
                  paste0("-s ", seed),
                  paste0("-t ", 100),
                  paste0("-l ", lambdar_range[i]),
                  # The population scale is indexed by j (the pop_range loop),
                  # not by i.
                  paste0("--meanscale ", pop_range[j]),
                  paste0("--sdscale ", 1/2),
                  paste0("--kappa ", 1/4),
                  paste0("--nu ", 1/3),
                  paste0("--sdk ", 0.5),
                  paste0("-o ", base_dir, "/test_out/", dir_name,
                         "/pop_", j, "_r_", l),
                  paste0("--metadata ",
                         "\"lambdar_idx:", i, ",N_idx:", j, "\""))
      fname <- paste0("pop_", j, "_r_", l, ".txt")
      # writeLines() opens and closes the file itself when given a path.
      writeLines(params, file.path(base_dir, "test_spec", dir_name, fname))
      seed <- seed + 1
    }
  }
}
# Shell step 1: simulate one genealogy per spec file.
# Every spec file written by the R loop holds 11 lines (one option each), so
# `xargs -L 1 cat` streams all option lines and `parallel -n 11` hands one
# file's worth of options to each simulate_tree.R call (10 jobs at a time).
# NOTE(review): assumes base_dir == ./Paper/Sim/sim_one_exp -- confirm.
# NOTE(review): the grep patterns are unquoted regexes; quoting them
# (e.g. '\.txt$') avoids shell glob expansion and loose matches.
find ./Paper/Sim/sim_one_exp | grep .*.txt | xargs -L 1 cat | parallel -j 10 --verbose --tmpdir ./Paper/Sim/tmp -n 11 --halt-on-error 2 eval Rscript "./Scripts/simulate_tree.R {}"
# Shell step 2: run MCMC inference on every simulated tree.nwk, with -o
# pointing at the tree's own directory (the summary code later reads
# expansions.rds from those directories).
# NOTE(review): -n 1e7 / -t 1e3 / --lambdar 5 are hard-coded and appear to
# mirror n_it, thinning and lambdar_range in the R settings; keep in sync.
find ./Paper/Sim/sim_one_exp | grep .*tree.nwk | xargs dirname | parallel -j 10 --verbose --tmpdir ./Paper/Sim/tmp --halt-on-error 2 eval Rscript "./Scripts/run_expansions_inference_nwk.R -n 1e7 -t 1e3 --lambdar 5 -s 1 -f {}/tree.nwk -o {}"
# Equal-tailed credible interval from a vector of posterior samples.
#   x:    numeric vector of samples; NAs are ignored.
#   conf: probability mass inside the interval, in (0, 1). The default 0.95
#         gives the 2.5% and 97.5% sample quantiles.
# Returns c(lower, upper), or c(NA, NA) when there are no usable samples.
# Bounds are actual sample values (quantile type 1, the inverse empirical CDF).
# This replaces a hand-rolled half-split that indexed at fractional positions
# and errored ("replacement has length zero") for fewer than ~40 samples.
compute_ci <- function(x, conf = 0.95) {
  stopifnot(is.numeric(conf), length(conf) == 1L, conf > 0, conf < 1)
  x <- x[!is.na(x)]
  if (length(x) == 0L) {
    return(c(NA_real_, NA_real_))
  }
  alpha <- (1 - conf) / 2
  unname(quantile(x, probs = c(alpha, 1 - alpha), type = 1))
}
## Per-experiment summaries ----
## For every simulated experiment, pair the ground truth stored in
## tree_params.json with posterior summaries of the MCMC output stored in
## expansions.rds in the same directory.
experiment_specs <- list.files(path = paste0(base_dir, "/test_out"),
                               pattern = ".*tree_params.json",
                               full.names = TRUE, recursive = TRUE)
experiment_dirs <- dirname(experiment_specs)

# Most frequent value of `x`; ties go to the value that occurs first in `x`.
most_frequent <- function(x) {
  values <- unique(x)
  values[which.max(tabulate(match(x, values), nbins = length(values)))]
}

# Summarise one experiment as a flat named list (one row of the results table).
#   spec_file: path to the run's tree_params.json (ground truth)
#   exp_dir:   directory holding the run's expansions.rds
# Summaries that need samples in the true number of expansions are NA (and the
# branch probabilities 0) when the sampler never visited that dimension.
summarise_experiment <- function(spec_file, exp_dir) {
  sim_data <- fromJSON(file = spec_file)
  sim_data <- lapply(sim_data, function(v) as.numeric(unlist(v)))
  dim.gt <- sim_data$n_exp
  t_mid.gt <- sim_data$t_mid
  K.gt <- sim_data$K
  N.gt <- sim_data$N
  br.gt <- sim_data$root_set
  time.gt <- sim_data$div_times
  time.gt <- time.gt[-length(time.gt)]  # last div_times entry is not compared

  expansions <- readRDS(file.path(exp_dir, "expansions.rds"))
  expansions <- discard_burn_in(expansions, proportion = burn_in)
  pre <- expansions$phylo_preprocessed
  mcmc_data <- expansions$model_data
  event_data <- expansions$expansion_data

  ### Check that we got the number of expansions approximately right
  p_correct_dim <- mean(mcmc_data$dim == dim.gt)
  expected_dim <- mean(mcmc_data$dim)
  mode_dim <- most_frequent(mcmc_data$dim)

  ### Marginal over samples with the true number of expansions
  correct_dim <- mcmc_data[which(mcmc_data$dim == dim.gt), ]
  event_dim_marginal <- event_data[event_data$it %in% correct_dim$it, ]

  # Defaults for when the true dimension was never sampled. Every field of the
  # returned row must be defined on this path too.
  mode_branch <- NA
  is_mode_correct <- NA
  expected_N <- NA
  expected_T <- NA
  expected_K <- NA
  expected_t_mid <- NA
  ci_N <- c(NA, NA)
  ci_T <- c(NA, NA)
  ci_K <- c(NA, NA)
  ci_t_mid <- c(NA, NA)
  jacc_dist <- NA
  p_correct_br <- 0
  p_mode_br <- 0

  if (nrow(event_dim_marginal) > 0) {
    expected_N <- median(correct_dim$N)
    ci_N <- compute_ci(correct_dim$N)
    p_correct_br <- mean(event_dim_marginal$br == br.gt)
    mode_branch <- most_frequent(event_dim_marginal$br)
    is_mode_correct <- mode_branch == br.gt

    ### Marginal over events placed on the modal branch
    in_mode <- which(event_dim_marginal$br == mode_branch)
    p_mode_br <- length(in_mode) / nrow(event_dim_marginal)
    event_br_marginal <- event_dim_marginal[in_mode, ]
    expected_t_mid <- median(event_br_marginal$t_mid)
    ci_t_mid <- compute_ci(event_br_marginal$t_mid)
    expected_K <- median(event_br_marginal$K)
    ci_K <- compute_ci(event_br_marginal$K)
    expected_T <- median(-event_br_marginal$time)
    ci_T <- compute_ci(-event_br_marginal$time)

    ### Jaccard distance between the tips below the true and modal branches
    mrca.gt <- pre$edges.df$node.child[br.gt]
    mrca.mode <- pre$edges.df$node.child[mode_branch]
    gt.tips <- pre$clades.list[[mrca.gt - pre$n_tips]]$tip.label
    mode.tips <- pre$clades.list[[mrca.mode - pre$n_tips]]$tip.label
    n_shared <- sum(gt.tips %in% mode.tips)
    jacc_dist <- 1 - n_shared / (length(gt.tips) + length(mode.tips) - n_shared)
  }

  list(p_correct_dim = p_correct_dim,
       expected_dim = expected_dim,
       mode_dim = mode_dim,
       mode_branch = mode_branch,
       expected_N = expected_N,
       ci_N_lo = ci_N[1],
       ci_N_hi = ci_N[2],
       expected_T = expected_T,
       ci_T_lo = ci_T[1],
       ci_T_hi = ci_T[2],
       expected_K = expected_K,
       ci_K_lo = ci_K[1],
       ci_K_hi = ci_K[2],
       expected_t_mid = expected_t_mid,
       ci_t_mid_lo = ci_t_mid[1],
       ci_t_mid_hi = ci_t_mid[2],
       p_correct_br = p_correct_br,
       p_mode_br = p_mode_br,
       is_mode_correct = is_mode_correct,
       jacc_dist = jacc_dist,
       t_mid_gt = t_mid.gt,
       dim_gt = dim.gt,
       K_gt = K.gt,
       N_gt = N.gt,
       time_gt = -time.gt,
       br_gt = br.gt)
}

results_data <- lapply(seq_along(experiment_specs), function(i) {
  summarise_experiment(experiment_specs[i], experiment_dirs[i])
})
## Combine per-experiment summaries into one data frame (one row per run) ----
# Fail with a clear message instead of an opaque names-length error when no
# experiment outputs were found.
if (length(results_data) == 0L) {
  stop("No experiment summaries to combine (no tree_params.json found).",
       call. = FALSE)
}
names(results_data) <- seq_along(results_data)
data_df <- do.call(rbind.data.frame, results_data)
# Runs whose posterior modal number of expansions is 1, the simulated value.
data_df_dim <- data_df[which(data_df$mode_dim == 1), ]
# All such runs are plotted. To also require the modal branch to be correct use:
# data_df_m <- data_df_dim[which(data_df_dim$is_mode_correct), ]
data_df_m <- data_df_dim
head(data_df)
## Figure 3a: recovery of the number of expansions and of the branch ----

# Bar chart of the share of runs per posterior modal number of expansions.
# group = 1 keeps `prop` relative to all runs even if mode_dim is discrete.
dim_hist <- ggplot(data_df, aes(x = mode_dim, group = 1)) +
  geom_bar(aes(y = after_stat(prop)), stat = "count") +
  geom_text(aes(label = scales::percent(after_stat(prop)),
                y = after_stat(prop)),
            stat = "count", vjust = -.5, size = 15) +
  theme_bw() +
  labs(x = "Posterior Maximum No. Exp.", y = "Mean Posterior") +
  scale_y_continuous(labels = scales::percent, limits = c(0, 1)) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        text = element_text(size = 35))

# 40-bin histogram of column `x` of `df`, with bar heights as the fraction of
# rows. The y axis is hidden because these panels sit beside dim_hist.
fraction_histogram <- function(df, x, xlab) {
  ggplot(df, aes(x = .data[[x]])) +
    geom_histogram(aes(y = after_stat(count / sum(count))), bins = 40) +
    scale_y_continuous(labels = scales::percent, limits = c(0, 1)) +
    theme_bw() +
    labs(x = xlab) +
    theme(axis.text.x = element_text(angle = 45, hjust = 1),
          axis.title.y = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks.y = element_blank(),
          text = element_text(size = 35))
}

p_hist <- fraction_histogram(data_df_dim, "p_correct_br", "Prob. Correct Branch")
jacc_hist <- fraction_histogram(data_df_dim, "jacc_dist", "Jaccard Distance")

# Three equal-width panels in one row. Without ncol/nrow, ggarrange lays
# 3 plots out on a 2x2 grid, which conflicts with the 3 widths/1 height and
# the 2:1 canvas.
p <- ggarrange(dim_hist, p_hist, jacc_hist, ncol = 3, nrow = 1,
               widths = c(2, 2, 2), heights = c(1))
png("./Paper/Figures/fig3a.png", width = 1600, height = 800)
print(p)  # explicit print so the file is written even when sourced
dev.off()
## png
## 2
## Figure 3b: posterior medians with credible intervals vs. ground truth ----

# Log-log scatter of the posterior median (column `y`, error bars from `lo` to
# `hi`) against the true value (column `x`), with the y = x reference line.
# Both axes share one range so the reference line is the diagonal. The range
# covers the intervals AND the true/median values: with limits taken from the
# intervals alone, true values outside every interval were silently dropped.
recovery_scatter <- function(df, x, y, lo, hi, xlab, ylab) {
  lims <- range(df[[lo]], df[[hi]], df[[x]], df[[y]], na.rm = TRUE)
  ggplot(df, aes(x = .data[[x]], y = .data[[y]])) +
    geom_point() +
    geom_errorbar(aes(ymin = .data[[lo]], ymax = .data[[hi]]),
                  width = 0.2, alpha = 0.3) +
    scale_x_continuous(trans = "log10", limits = lims) +
    scale_y_continuous(trans = "log10", limits = lims) +
    theme_bw() +
    labs(x = xlab, y = ylab) +
    coord_fixed(ratio = 1) +
    theme(axis.text.x = element_text(angle = 45, hjust = 1),
          text = element_text(size = 35)) +
    geom_abline(intercept = 0, slope = 1)
}

N_scatter <- recovery_scatter(data_df_m, "N_gt", "expected_N",
                              "ci_N_lo", "ci_N_hi",
                              "True Background Population",
                              "Posterior Median Background Population")
K_scatter <- recovery_scatter(data_df_m, "K_gt", "expected_K",
                              "ci_K_lo", "ci_K_hi",
                              "True Carrying Capacity",
                              "Posterior Median Carrying Capacity")
T_scatter <- recovery_scatter(data_df_m, "time_gt", "expected_T",
                              "ci_T_lo", "ci_T_hi",
                              "True Time of Expansions",
                              "Posterior Median Time of Expansions")
t_mid_scatter <- recovery_scatter(data_df_m, "t_mid_gt", "expected_t_mid",
                                  "ci_t_mid_lo", "ci_t_mid_hi",
                                  "True Time to Half Capacity",
                                  "Posterior Median Time to Half Capacity")

p <- ggarrange(N_scatter, K_scatter, t_mid_scatter, T_scatter,
               ncol = 2, nrow = 2, widths = c(2, 2), heights = c(2, 2))
png("./Paper/Figures/fig3b.png", width = 1600, height = 1600)
print(p)  # explicit print so the file is written even when sourced
dev.off()
## png
## 2